[P/D disagg] - support decode side radix cache #19746

Open: ishandhanani wants to merge 63 commits into main from ishan/add-radix-cache-decode
Conversation


@ishandhanani ishandhanani commented Mar 3, 2026

Summary

In PD disaggregation, the decode worker can now use radix cache to reuse shared prefixes and request only the delta KV from prefill instead of transferring the full prefix on every turn.

This is enabled with --disaggregation-decode-enable-radix-cache on the decode server.

For now, this path is supported only with --disaggregation-transfer-backend nixl. server_args.py now rejects other transfer backends early when the decode radix cache flag is enabled. Mooncake support will follow in a separate PR.
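The early rejection described above can be sketched roughly as follows. This is a hedged illustration, not the actual server_args.py code; the function and attribute names here are hypothetical.

```python
# Hypothetical sketch of the early validation described above; the real
# check lives in server_args.py and the exact names may differ.
def validate_decode_radix_cache_args(decode_enable_radix_cache: bool,
                                     transfer_backend: str) -> None:
    # Decode-side radix cache is currently wired up only for the NIXL path.
    if decode_enable_radix_cache and transfer_backend != "nixl":
        raise ValueError(
            "--disaggregation-decode-enable-radix-cache currently requires "
            "--disaggregation-transfer-backend nixl"
        )
```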

Main Changes

  • Decode scheduler
    • Match incoming requests against the decode-side radix tree.
    • Lock matched prefix nodes for the request lifetime.
    • Pre-allocate only the delta KV pages beyond the matched prefix.
  • Decode -> prefill protocol
    • Plumb decode_prefix_len from decode to prefill for the NIXL path.
    • Allow full-prefix hits where decode may need no KV pages transferred.
  • Prefill transfer path
    • Initialize the sender with only the unsent delta pages.
    • Keep the chunked transfer cursor monotonic when decode already has part of the prefix.
    • Skip empty non-last chunks so the sender/receiver chunk protocol stays consistent.
  • Correctness / cleanup
    • Align matched prefix length to page boundaries for paged KV allocators.
    • Guard lock release / cleanup paths for transfer-failure cases.
    • Batch finished prebuilt frees through the free-group path.
  • CLI / config
    • The user-facing switch is --disaggregation-decode-enable-radix-cache.
    • Current validation requires --disaggregation-transfer-backend nixl when that flag is set.
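Taken together, the delta the prefill sender must ship reduces to a page-range computation along these lines. This is a minimal sketch under the assumptions stated above (the decode-side prefix length is already aligned down to a page boundary), not the actual sender code.

```python
def delta_pages(total_kv_len: int, decode_prefix_len: int, page_size: int):
    """Pages prefill must still transfer, given decode's cached prefix.

    Assumes decode_prefix_len was already aligned down to a page boundary,
    per the correctness notes above. Illustrative only.
    """
    first_page = decode_prefix_len // page_size
    last_page = -(-total_kv_len // page_size)  # ceil division
    return list(range(first_page, last_page))
```

On a full-prefix hit the range is empty, which is why the chunk protocol must tolerate zero transferred pages, as noted in the protocol changes above.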

Interface

Enable decode radix cache on the decode worker with:

--disaggregation-mode decode --disaggregation-transfer-backend nixl --disaggregation-decode-enable-radix-cache

Prefill continues to run with --disaggregation-transfer-backend nixl.

Note: DP attention is still experimental here. The flag is allowed, but good cache hit rates require prefix-aware DP routing.

Benchmark

Setup

  • Hardware: 1x NVIDIA B200 node (8 GPUs), single-node PD disaggregation via NIXL
  • Model: Qwen/Qwen3-32B, FP8 KV cache, 3P1D, TP=2 per worker
  • Workload: 20 unique ~50K-token prefixes + ~4.5K suffix (~91% prefix reuse), 1000 requests, concurrency 128

Results

| Metric | Baseline | Decode Radix Cache | Improvement |
|---|---|---|---|
| Request throughput (req/s) | 1.21 | 1.59 | 1.32x |
| Output token throughput (tok/s) | 430 | 566 | 1.32x |
| TTFT p50 (s) | 73.2 | 9.0 | 8.1x |
| TTFT avg (s) | 77.7 | 31.6 | 2.5x |
| Request latency p50 (s) | 99.1 | 73.4 | 1.35x |
| ITL avg (ms) | 65.6 | 130.6 | 0.50x (regression) |
| Benchmark duration (s) | 827 | 628 | 1.32x |

Decode-side logs show the reason for the throughput gain: baseline decode ran near KV capacity (token_usage ~ 0.99) and only fit ~37 running requests, while decode radix cache reduced duplicate prefix residency (token_usage ~ 0.75) and fit roughly 104-126 running requests. The ITL regression is expected from the larger decode batch.
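Sanity-checking the headline multipliers from the rounded figures in the table (small differences vs. the reported 1.32x come from rounding):

```python
# Recompute the improvement ratios from the rounded table figures.
baseline = {"req_per_s": 1.21, "out_tok_per_s": 430, "ttft_p50_s": 73.2}
radix = {"req_per_s": 1.59, "out_tok_per_s": 566, "ttft_p50_s": 9.0}

req_gain = radix["req_per_s"] / baseline["req_per_s"]          # ~1.31x
tok_gain = radix["out_tok_per_s"] / baseline["out_tok_per_s"]  # ~1.32x
ttft_gain = baseline["ttft_p50_s"] / radix["ttft_p50_s"]       # ~8.1x
```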

Test Plan

  • Qwen3-0.6B local PD disagg sanity runs
  • MiniMax-M2.5 1P1D on B200
  • Qwen3-32B 3P1D on B200 (results above)
  • Guard decode radix cache behind nixl in server_args.py
  • Multi-node cross-host testing (RDMA transport)
  • Mooncake transfer backend support (separate PR)


@ishandhanani ishandhanani changed the title [Draft] [P/D disagg] - support decode side radix cache [P/D disagg] - support decode side radix cache Mar 3, 2026

dongyibo commented Mar 3, 2026

@ishandhanani Can this feature be understood as follows:
In a multi-turn dialogue scenario, the first round takes tokens 1, 2, 3 as input and outputs tokens 4, 5, 6.
The second round takes tokens 1, 2, 3, 4, 5, 6, 7, 8, 9 as input and outputs tokens 10, 11, 12.

Current status of pd-disagg:
In the first round, for the decode worker, the generated tokens 4, 5, 6 are not cached, and the KV cache of the input tokens 1, 2, 3 is not saved.
In the second round, the prefill worker needs to send the KV cache for all tokens 1, 2, 3, 4, 5, 6, 7, 8, and 9 to the decode worker.

Based on this PR's implementation:
In the first round, the decode worker saves the KV cache for tokens 1, 2, 3, 4, 5, and 6.
In the second round, the prefill worker only needs to send the KV cache for tokens 7, 8, and 9 to the decode worker.
Is my understanding correct?


ishandhanani commented Mar 3, 2026

(quoting dongyibo's multi-turn question above)

Yep, this is correct.
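The exchange above can be condensed into a toy sketch. This is purely illustrative (page_size=1 by default for simplicity; real paged allocators first align the cached prefix down to a page boundary, as this PR does):

```python
def tokens_prefill_must_send(full_input, decode_cached_len, page_size=1):
    # Round the decode-side cached prefix down to a page boundary,
    # then ship only the remaining tokens' KV.
    aligned = (decode_cached_len // page_size) * page_size
    return full_input[aligned:]

# Round 2 of the example: decode already holds KV for tokens 1..6.
round2_input = [1, 2, 3, 4, 5, 6, 7, 8, 9]
delta = tokens_prefill_must_send(round2_input, decode_cached_len=6)  # [7, 8, 9]
```

With page_size > 1, the cached length rounds down first, so a cached prefix of 6 tokens with 4-token pages yields a 4-token usable prefix and a 5-token delta.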

@ishandhanani
Collaborator Author

/gemini review


- Set req.prefix_indices in _pre_alloc so init_next_round_input(None)
  computes extend_input_len correctly from the cached prefix length.
  Without this, prepare_for_prebuilt runs a full-length extend instead
  of a delta extend.

- Always call inc_lock_ref on the matched node (even on empty match)
  to match aggregated scheduler behavior. Prevents lock_ref underflow
  when cache_finished_req unconditionally calls dec_lock_ref.


ishandhanani commented Mar 4, 2026

Next step is testing with a larger model on B200. The step after that (maybe in a follow-up) is to do the same for Mooncake.

Comment thread python/sglang/srt/disaggregation/prefill.py Outdated

dongyibo commented Mar 4, 2026

@ishandhanani There seems to be a constraint here:
For multiple decode workers, such as when decode is run with DP, it's best if the same DP rank is used for the entire conversation; otherwise, the cached KV cache cannot be utilized?


nananall commented Mar 4, 2026

Could you share the exact command you used to run this? I'd like to reproduce it and test it on my side.


ishandhanani commented Mar 4, 2026

@ishandhanani There seems to be a constraint here: For multiple decode workers, such as when decode is run with DP, it's best if the same DP rank is used for the entire conversation; otherwise, the cached KV cache cannot be utilized?

There are a few things here.

  1. When running with multiple decode workers (standard data parallelism of workers), I expect the router to pick the right decode worker based on KV load. The dynamo router handles this very well and performantly out of the box.
  2. For DP attention - agreed. I have not added support yet. Need to do this.

Comment thread python/sglang/srt/models/qwen3.py Outdated
Comment on lines +555 to +567
need_poll = len(self.queue) > 0 and not all(
    decode_req.waiting_for_input for decode_req in self.queue
)
# All TPs must agree on whether to poll and on queue size, otherwise
# poll_and_all_reduce (which sizes its tensor by queue length) hangs.
if dist.get_world_size(self.gloo_group) > 1:
    n = len(self.queue)
    local = torch.tensor(
        [int(need_poll), n, -n], dtype=torch.int64, device="cpu"
    )
    dist.all_reduce(local, op=dist.ReduceOp.MIN, group=self.gloo_group)
    if local[0].item() == 0 or local[1].item() != -local[2].item():
        return
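The guard relies on a single MIN all-reduce to detect disagreement: reducing n with MIN yields min(n) across ranks, and reducing -n with MIN yields -max(n), so the two halves match only when every rank sees the same queue size. A torch-free sketch of that trick (this mirrors the logic only, not the actual distributed call):

```python
def ranks_agree_on_queue_size(queue_sizes):
    # Emulate an element-wise MIN all-reduce over the pair (n, -n):
    min_n = min(queue_sizes)                   # MIN over n  -> smallest queue
    min_neg_n = min(-n for n in queue_sizes)   # MIN over -n -> -(largest queue)
    # Equal only when min(n) == max(n), i.e. all ranks agree.
    return min_n == -min_neg_n
```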

@ShangmingCai ShangmingCai Apr 13, 2026


Why this part? I think this issue should be resolved by #21299

Collaborator Author (ishandhanani) replied:

I rechecked this after the main merge. We are keeping the all-reduce guard intentionally because it came from #22234 and protects poll_and_all_reduce from transient TP queue-size divergence. Restored in b24f58d07.

Comment thread python/sglang/srt/disaggregation/decode.py Outdated
Comment on lines +785 to +789
allocatable_tokens = self._allocatable_tokens(
    retractable_tokens=retractable_tokens,
    count_retracted=True,
    extra_reserved_reqs=len(preallocated_reqs) + 1,
)
Collaborator:

Do these lines take effect? I didn't find where allocatable_tokens is used below in pop_preallocated.

Collaborator Author (ishandhanani) replied:

It affects the next iteration of the pop_preallocated loop. I kept the recompute and tightened the comment to state that it refreshes the budget for the next queue entry after page rounding and newly locked evictable cache state.
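The shape of that budget refresh can be abstracted like this (an illustrative sketch, with hypothetical names, not the real pop_preallocated loop):

```python
def pop_within_budget(queue_needs, allocatable_tokens):
    """Pop queue entries while the refreshed token budget still covers them."""
    popped = []
    for need in queue_needs:
        if need > allocatable_tokens:
            break  # budget exhausted for this entry; stop popping
        popped.append(need)
        allocatable_tokens -= need  # refreshed budget seen by the next entry
    return popped
```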

Comment thread python/sglang/srt/disaggregation/decode.py Outdated
Comment on lines 1012 to +1026
if self.scheduler.enable_hisparse:
    # Direct-to-host path: only allocate logical indices (no hisparse
    # device indices) and allocate host indices for RDMA destination.
    coordinator = self.scheduler.hisparse_coordinator
    device = self.token_to_kv_pool_allocator.device
    last_loc = (
        prefix_indices[-1:].to(dtype=torch.int64, device=device)
        if prefix_len > 0
        else torch.tensor([-1], dtype=torch.int64, device=device)
    )
    kv_loc = self.token_to_kv_pool_allocator.alloc_logical_only(
-       prefix_lens=torch.tensor([0], dtype=torch.int64, device=device),
-       prefix_lens_cpu=torch.tensor([0], dtype=torch.int64),
+       prefix_lens=torch.tensor(
+           [prefix_len], dtype=torch.int64, device=device
+       ),
+       prefix_lens_cpu=torch.tensor([prefix_len], dtype=torch.int64),
Collaborator:

CC: @hzh0425

Collaborator Author (ishandhanani) replied:

I merged main's newer HiSparse admission/logical-pool accounting and gated the decode-radix path so HiSparse stays effectively on the upstream/flag-off behavior. Decode radix + HiSparse remains rejected in server args.

@hzh0425 hzh0425 (Collaborator) commented Apr 14, 2026

We don't actually need to modify HiSparse here, because HiSparse is incompatible with the L1 RadixTree; we can just roll back the code.
@ishandhanani

Comment on lines +1236 to +1247
# All TPs must agree on queue size before poll_and_all_reduce.
# _resolve_pending_reqs does independent HTTP calls per TP, so queue
# sizes can transiently diverge; a mismatched all_reduce corrupts gloo.
if dist.get_world_size(self.gloo_group) > 1:
    n = len(self.queue)
    local = torch.tensor([n, -n], dtype=torch.int64, device="cpu")
    dist.all_reduce(local, op=dist.ReduceOp.MIN, group=self.gloo_group)
    if local[0].item() != -local[1].item():
        return []
    if local[0].item() == 0:
        return []
elif not self.queue:
Collaborator:

These can be reverted as well after #21299

Collaborator Author (ishandhanani) replied:

I rechecked this after the main merge. We are keeping the all-reduce queue-size guard intentionally because it came from #22234 and protects poll_and_all_reduce from transient TP queue-size divergence. Restored in b24f58d07.

Comment on lines +1482 to +1483
if req.kv_committed_len is not None:
    req.fill_ids = req.fill_ids[: req.kv_committed_len]
Collaborator:

Should we call req.set_extend_input_len here?

Collaborator Author (ishandhanani) replied:

Done. After truncating req.fill_ids to kv_committed_len, we now call req.set_extend_input_len(len(req.fill_ids) - len(req.prefix_indices)).
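The resulting pairing of the truncation with the extend-length update looks roughly like this (a plain-Python sketch with the Req fields passed in explicitly; not the actual scheduler code):

```python
def truncate_to_committed(fill_ids, prefix_indices, kv_committed_len):
    # Drop any fill ids beyond the KV actually committed on this worker,
    # then recompute how many tokens the next extend must cover.
    if kv_committed_len is not None:
        fill_ids = fill_ids[:kv_committed_len]
    extend_input_len = len(fill_ids) - len(prefix_indices)
    return fill_ids, extend_input_len
```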

@ShangmingCai ShangmingCai (Collaborator) left a comment:

I just finished another round of review of this PR, @cctry could you check whether these comments are considerable?


ishandhanani commented Apr 14, 2026

Benchmark update for the latest decode-radix-cache-0413.sqsh

Current apples-to-apples pair:

  • Both completed successfully with 1000 requests, concurrency 128, 50K prefix + 4.5K suffix.
  • OSL matched: baseline avg 355.400 output tokens/request, radix avg 355.396 output tokens/request.
  • No leaks
| Metric | Baseline 12486 | Radix 12485 | Delta |
|---|---|---|---|
| Request throughput | 1.2433 req/s | 1.6075 req/s | +29.29% |
| Output token throughput | 441.87 tok/s | 571.30 tok/s | +29.29% |
| Total token throughput | 68247.77 tok/s | 88239.17 tok/s | +29.29% |
| TTFT p50 | 71940.70 ms | 6993.33 ms | -90.28% |
| Request latency p50 | 98049.79 ms | 70669.68 ms | -27.93% |
| ITL avg | 65.17 ms | 134.63 ms | +106.58% (radix worse) |
| E2E duration | 804.31 s | 622.08 s | -22.66% |

Main read: with matched OSL, decode radix gives a clean 1.29x request/output throughput improvement and ~10.3x better TTFT p50. ITL is worse, but end-to-end request latency and total duration are better.

Commands for baseline 12486

AIPerf:

aiperf profile --model 'Qwen/Qwen3-32B' --url 'http://gpu-3:8000' --endpoint-type 'chat' --tokenizer '/fsw-home/qwen32b' --max-workers 16 --streaming --ui-type None --artifact-dir '/scratch/fsw/ishan/ignition/outputs/12486/results/sweep_000_prefix_isl=50000_suffix_isl=4500' --request-timeout-seconds 10800 --synthetic-input-tokens-mean 4500 --synthetic-input-tokens-stddev 500 --prefix-prompt-length 50000 --num-prefix-prompts 20 --output-tokens-mean 350 --output-tokens-stddev 100 --num-dataset-entries 1000 --random-seed 42 --concurrency 128 --request-count 1000 --export-level 'summary' --no-gpu-telemetry --no-server-metrics --extra-inputs '{"ignore_eos":true}'

Server commands extracted from the srun launch lines:

# prefill 0
python3 -m dynamo.sglang --model-path /scratch/fsw/ishan/qwen32b --served-model-name Qwen/Qwen3-32B --host 0.0.0.0 --dump-config-to /scratch/fsw/ishan/ignition/outputs/12486/logs/prefill_config_endpoint_0_node_gpu-3_12486.json --enable-metrics --disaggregation-mode prefill --trust-remote-code --kv-cache-dtype fp8_e4m3 --attention-backend flashinfer --context-length 131072 --disaggregation-transfer-backend nixl --enable-symm-mem --enable-single-batch-overlap --max-prefill-tokens 32768 --scheduler-recv-interval 1 --stream-interval 30 --watchdog-timeout 1000000 --log-level debug --page-size 64 --json-model-override-args '{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768},"max_position_embeddings":131072}' --tensor-parallel-size 2 --chunked-prefill-size 32768 --mem-fraction-static 0.90 --cuda-graph-max-bs 256 --max-running-requests 256 --port 10000 --disaggregation-bootstrap-port 13000

# prefill 1
python3 -m dynamo.sglang --model-path /scratch/fsw/ishan/qwen32b --served-model-name Qwen/Qwen3-32B --host 0.0.0.0 --dump-config-to /scratch/fsw/ishan/ignition/outputs/12486/logs/prefill_config_endpoint_1_node_gpu-3_12486.json --enable-metrics --disaggregation-mode prefill --trust-remote-code --kv-cache-dtype fp8_e4m3 --attention-backend flashinfer --context-length 131072 --disaggregation-transfer-backend nixl --enable-symm-mem --enable-single-batch-overlap --max-prefill-tokens 32768 --scheduler-recv-interval 1 --stream-interval 30 --watchdog-timeout 1000000 --log-level debug --page-size 64 --json-model-override-args '{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768},"max_position_embeddings":131072}' --tensor-parallel-size 2 --chunked-prefill-size 32768 --mem-fraction-static 0.90 --cuda-graph-max-bs 256 --max-running-requests 256 --port 10100 --disaggregation-bootstrap-port 13100

# prefill 2
python3 -m dynamo.sglang --model-path /scratch/fsw/ishan/qwen32b --served-model-name Qwen/Qwen3-32B --host 0.0.0.0 --dump-config-to /scratch/fsw/ishan/ignition/outputs/12486/logs/prefill_config_endpoint_2_node_gpu-3_12486.json --enable-metrics --disaggregation-mode prefill --trust-remote-code --kv-cache-dtype fp8_e4m3 --attention-backend flashinfer --context-length 131072 --disaggregation-transfer-backend nixl --enable-symm-mem --enable-single-batch-overlap --max-prefill-tokens 32768 --scheduler-recv-interval 1 --stream-interval 30 --watchdog-timeout 1000000 --log-level debug --page-size 64 --json-model-override-args '{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768},"max_position_embeddings":131072}' --tensor-parallel-size 2 --chunked-prefill-size 32768 --mem-fraction-static 0.90 --cuda-graph-max-bs 256 --max-running-requests 256 --port 10200 --disaggregation-bootstrap-port 13200

# decode 0
python3 -m dynamo.sglang --model-path /scratch/fsw/ishan/qwen32b --served-model-name Qwen/Qwen3-32B --host 0.0.0.0 --dump-config-to /scratch/fsw/ishan/ignition/outputs/12486/logs/decode_config_endpoint_0_node_gpu-3_12486.json --enable-metrics --disaggregation-mode decode --trust-remote-code --kv-cache-dtype fp8_e4m3 --attention-backend flashinfer --context-length 131072 --disaggregation-transfer-backend nixl --enable-symm-mem --enable-single-batch-overlap --max-prefill-tokens 32768 --scheduler-recv-interval 1 --stream-interval 30 --watchdog-timeout 1000000 --log-level debug --page-size 64 --json-model-override-args '{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768},"max_position_embeddings":131072}' --tensor-parallel-size 2 --chunked-prefill-size -1 --mem-fraction-static 0.90 --cuda-graph-max-bs 256 --max-running-requests 256 --load-balance-method total_tokens --prefill-round-robin-balance --num-reserved-decode-tokens 400 --port 11000 --disaggregation-bootstrap-port 14000

# frontend
python3 -m dynamo.frontend --http-port 8001

# nginx
nginx -c /scratch/fsw/ishan/ignition/outputs/12486/logs/nginx.conf -g 'daemon off;'
Commands for radix 12485

AIPerf:

aiperf profile --model 'Qwen/Qwen3-32B' --url 'http://gpu-1:8000' --endpoint-type 'chat' --tokenizer '/fsw-home/qwen32b' --max-workers 16 --streaming --ui-type None --artifact-dir '/scratch/fsw/ishan/ignition/outputs/12485/results/sweep_000_prefix_isl=50000_suffix_isl=4500' --request-timeout-seconds 10800 --synthetic-input-tokens-mean 4500 --synthetic-input-tokens-stddev 500 --prefix-prompt-length 50000 --num-prefix-prompts 20 --output-tokens-mean 350 --output-tokens-stddev 100 --num-dataset-entries 1000 --random-seed 42 --concurrency 128 --request-count 1000 --export-level 'summary' --no-gpu-telemetry --no-server-metrics --extra-inputs '{"ignore_eos":true}'

Server commands extracted from the srun launch lines:

# prefill 0
python3 -m dynamo.sglang --model-path /scratch/fsw/ishan/qwen32b --served-model-name Qwen/Qwen3-32B --host 0.0.0.0 --dump-config-to /scratch/fsw/ishan/ignition/outputs/12485/logs/prefill_config_endpoint_0_node_gpu-1_12485.json --enable-metrics --disaggregation-mode prefill --trust-remote-code --kv-cache-dtype fp8_e4m3 --attention-backend flashinfer --context-length 131072 --disaggregation-transfer-backend nixl --enable-symm-mem --enable-single-batch-overlap --max-prefill-tokens 32768 --scheduler-recv-interval 1 --stream-interval 30 --watchdog-timeout 1000000 --log-level debug --page-size 64 --json-model-override-args '{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768},"max_position_embeddings":131072}' --tensor-parallel-size 2 --chunked-prefill-size 32768 --mem-fraction-static 0.90 --cuda-graph-max-bs 256 --max-running-requests 256 --port 10000 --disaggregation-bootstrap-port 13000

# prefill 1
python3 -m dynamo.sglang --model-path /scratch/fsw/ishan/qwen32b --served-model-name Qwen/Qwen3-32B --host 0.0.0.0 --dump-config-to /scratch/fsw/ishan/ignition/outputs/12485/logs/prefill_config_endpoint_1_node_gpu-1_12485.json --enable-metrics --disaggregation-mode prefill --trust-remote-code --kv-cache-dtype fp8_e4m3 --attention-backend flashinfer --context-length 131072 --disaggregation-transfer-backend nixl --enable-symm-mem --enable-single-batch-overlap --max-prefill-tokens 32768 --scheduler-recv-interval 1 --stream-interval 30 --watchdog-timeout 1000000 --log-level debug --page-size 64 --json-model-override-args '{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768},"max_position_embeddings":131072}' --tensor-parallel-size 2 --chunked-prefill-size 32768 --mem-fraction-static 0.90 --cuda-graph-max-bs 256 --max-running-requests 256 --port 10100 --disaggregation-bootstrap-port 13100

# prefill 2
python3 -m dynamo.sglang --model-path /scratch/fsw/ishan/qwen32b --served-model-name Qwen/Qwen3-32B --host 0.0.0.0 --dump-config-to /scratch/fsw/ishan/ignition/outputs/12485/logs/prefill_config_endpoint_2_node_gpu-1_12485.json --enable-metrics --disaggregation-mode prefill --trust-remote-code --kv-cache-dtype fp8_e4m3 --attention-backend flashinfer --context-length 131072 --disaggregation-transfer-backend nixl --enable-symm-mem --enable-single-batch-overlap --max-prefill-tokens 32768 --scheduler-recv-interval 1 --stream-interval 30 --watchdog-timeout 1000000 --log-level debug --page-size 64 --json-model-override-args '{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768},"max_position_embeddings":131072}' --tensor-parallel-size 2 --chunked-prefill-size 32768 --mem-fraction-static 0.90 --cuda-graph-max-bs 256 --max-running-requests 256 --port 10200 --disaggregation-bootstrap-port 13200

# decode 0
python3 -m dynamo.sglang --model-path /scratch/fsw/ishan/qwen32b --served-model-name Qwen/Qwen3-32B --host 0.0.0.0 --dump-config-to /scratch/fsw/ishan/ignition/outputs/12485/logs/decode_config_endpoint_0_node_gpu-1_12485.json --enable-metrics --disaggregation-mode decode --trust-remote-code --kv-cache-dtype fp8_e4m3 --attention-backend flashinfer --context-length 131072 --disaggregation-transfer-backend nixl --enable-symm-mem --enable-single-batch-overlap --max-prefill-tokens 32768 --scheduler-recv-interval 1 --stream-interval 30 --watchdog-timeout 1000000 --log-level debug --page-size 64 --json-model-override-args '{"rope_scaling":{"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768},"max_position_embeddings":131072}' --tensor-parallel-size 2 --chunked-prefill-size -1 --mem-fraction-static 0.90 --cuda-graph-max-bs 256 --max-running-requests 256 --load-balance-method total_tokens --prefill-round-robin-balance --num-reserved-decode-tokens 400 --disaggregation-decode-enable-radix-cache --port 11000 --disaggregation-bootstrap-port 14000

# frontend
python3 -m dynamo.frontend --http-port 8001

# nginx
nginx -c /scratch/fsw/ishan/ignition/outputs/12485/logs/nginx.conf -g 'daemon off;'

@yudian0504
Contributor

(quoting ishandhanani's benchmark update above)

Can we check the actual concurrency/batch size of the decode nodes during the load test? In theory, the decode nodes should run larger batches because they share more prefix cache, but the ITL degrading this much is a bit beyond my expectations.

Comment thread python/sglang/srt/models/qwen3.py Outdated
Comment on lines +90 to +100
try:
    for req in batch.reqs:
        req.time_stats.set_decode_prebuilt_finish_time()
        req.check_finished()
        if req.finished():
            req.time_stats.set_quick_finish_time()
            release_kv_cache(req, self.tree_cache)

    # Note: Logprobs should be handled on the prefill engine.
    self.stream_output(batch.reqs, batch.return_logprob)
finally:
@ShangmingCai ShangmingCai (Collaborator) commented Apr 14, 2026

Why do we need a try here? With the protection of use_free_group, I think it should be safe when the decode radix cache is not enabled.

So is this try added to protect the case when the decode radix cache is enabled?


ishandhanani commented Apr 14, 2026

Follow-up on the decode batch / ITL question @yudian0504. I think this delta comes from rate matching: for a workload that reuses this much KV cache, we do not need 3 prefill workers per 1 decode worker. To test this, I capped the decode worker's max running requests at 64.

| run | req/s | out tok/s | TTFT p50 | ITL avg | ITL p50 | latency p50 |
|---|---|---|---|---|---|---|
| baseline 12488 | 1.2312 | 437.58 | 72417 ms | 66.46 ms | 74.06 ms | 98681 ms |
| radix mrr256 12487 | 1.6196 | 575.59 | 7297 ms | 130.69 ms | 166.91 ms | 71844 ms |
| radix mrr64 12490 | 1.5485 | 550.34 | 39534 ms | 79.04 ms | 91.81 ms | 73956 ms |

The decode logs confirm the cap took effect: job 12490 held a steady-state #running-req of avg/p50/max 59/60/64. The previous radix run reached a max decode batch of 126. So I think much of the ITL variance comes from the larger decode batch size. Happy to investigate more.

Comment on lines +747 to +753
```python
prefix_indices, prefix_len = self._match_prefix_and_lock(decode_req.req)
# Align prefix_len down to page boundary so both prefill and
# decode agree on the page-aligned split point for KV transfer.
page_size = self.token_to_kv_pool_allocator.page_size
if page_size > 1 and prefix_len % page_size != 0:
    prefix_len = page_align_floor(prefix_len, page_size)
    prefix_indices = prefix_indices[:prefix_len]
```
Collaborator:

Are the returned indices from the tree always page-aligned?

Comment on lines +437 to +459
```python
def _match_prefix_and_lock(self, req: Req) -> Tuple[torch.Tensor, int]:
    """
    Match a request against the decode-side radix cache, lock the matched
    node to prevent eviction, and return the matched prefix information.
    """
    result = self.tree_cache.match_prefix(
        MatchPrefixParams(
            key=RadixKey(req.origin_input_ids, extra_key=req.extra_key),
            req=req,
            cow_mamba=self.tree_cache.supports_mamba(),
        )
    )
    prefix_indices = result.device_indices
    last_device_node = result.last_device_node
    # Always lock to match aggregated scheduling behavior
    self.tree_cache.inc_lock_ref(last_device_node)

    # We do this to ensure that whenever dec_lock_ref is called
    # on the Req object, we are not dereferencing a `None`. In the
    # agg case, the scheduler does this already.
    req.last_node = last_device_node

    return prefix_indices, len(prefix_indices)
```
Collaborator:

We have the same logic in schedule_policy.py; merge them.

Comment on lines +454 to +456
```python
def page_align_floor(length: int, page_size: int) -> int:
    """Round length down to the nearest page boundary."""
    return (length // page_size) * page_size
```
@cctry (Collaborator) commented Apr 15, 2026:

nit: this is a more general function. we can move it out of disaggregation/
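For illustration, with the 64-token pages used in the launch command above, the helper rounds a matched prefix down to a whole number of pages (standalone copy of the function, toy lengths):

```python
def page_align_floor(length: int, page_size: int) -> int:
    """Round length down to the nearest page boundary."""
    return (length // page_size) * page_size

# A 150-token radix match covers 2 full 64-token pages; the 22-token
# remainder is dropped so prefill and decode agree on the split point.
print(page_align_floor(150, 64))  # -> 128
print(page_align_floor(128, 64))  # -> 128
print(page_align_floor(63, 64))   # -> 0
```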

Comment on lines +387 to +391
```python
if (
    server_args.disaggregation_mode != "null"
    and not server_args.disable_radix_cache
):
    await _global_state.tokenizer_manager.flush_cache()
```
Collaborator:

This flush_cache might fail because new requests might arrive earlier than this command. A better solution is to not insert into the cache for the fake bootstrap host at all.

Comment on lines +1983 to +2002
```python
if (
    server_args.disaggregation_mode != "null"
    and not server_args.disable_radix_cache
):
    try:
        flush_res = requests.post(
            url + "/flush_cache",
            headers=headers,
            timeout=30,
            verify=ssl_verify,
        )
        if flush_res.status_code == 200:
            logger.info("Flushed warmup cache")
        else:
            logger.warning(
                f"Warmup cache flush failed: {flush_res.status_code}"
            )
    except Exception as e:
        logger.warning(f"Warmup cache flush request failed: {e}")
logger.info("End of disaggregation warmup")
```
Collaborator:

Ditto. We don't guarantee the warmup request will block subsequent requests, IIRC.

Comment on lines +813 to +819
```python
kv_indices = (
    self.req_to_token_pool.req_to_token[decode_req.req.req_pool_idx][
        prefix_len:origin_input_len
    ]
    .cpu()
    .numpy()
)
```
@cctry (Collaborator) commented Apr 15, 2026:

move this to nixl/conn.py? this is specific to nixl's implementation of decode radix cache. there are other ways to resolve the delta and handle more complicated cases
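To illustrate the slice itself (toy numbers, not real pool contents): only the KV slot indices for the unmatched delta `[prefix_len, origin_input_len)` are handed to the transfer backend, and everything before `prefix_len` is served from the decode-side radix cache:

```python
import numpy as np

# Toy req_to_token row: the KV slot index for each token position of one
# request (values are made up for illustration).
req_to_token = np.array([100, 101, 102, 103, 200, 201, 202, 203, 300, 301, 302, 303])
prefix_len, origin_input_len = 4, 12  # decode already holds the first 4 tokens

# Only the delta needs to be transferred from prefill.
kv_indices = req_to_token[prefix_len:origin_input_len]
print(kv_indices.tolist())  # -> [200, 201, 202, 203, 300, 301, 302, 303]
```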

```python
allocatable_tokens = self._allocatable_tokens(
    retractable_tokens=retractable_tokens,
    count_retracted=True,
    extra_reserved_reqs=len(preallocated_reqs) + 1,
)
```
@cctry (Collaborator) commented Apr 15, 2026:

Why do we need extra_reserved_reqs, since those requests are already allocated?

Comment on lines +963 to +973
```python
def _required_alloc_tokens(self, *, fill_len: int, prefix_len: int) -> int:
    page_size = self.token_to_kv_pool_allocator.page_size
    if page_size == 1:
        return fill_len - prefix_len

    num_new_pages = get_num_new_pages(
        seq_lens=torch.tensor([fill_len], dtype=torch.int64),
        prefix_lens=torch.tensor([prefix_len], dtype=torch.int64),
        page_size=page_size,
    )
    return num_new_pages * page_size
```
Collaborator:

This is a general function. It can be moved outside, e.g. to mem_cache/common.py.
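As a sketch of the arithmetic only (a hypothetical standalone helper, not the sglang implementation; it assumes prefix_len is already page-aligned, which the matching code above ensures), the new-page count is just the pages needed to cover fill_len minus the pages the prefix already occupies:

```python
import math

def required_alloc_tokens(fill_len: int, prefix_len: int, page_size: int) -> int:
    """Tokens to allocate for the range [prefix_len, fill_len).

    Assumes prefix_len is page-aligned, so its pages need no new allocation.
    """
    if page_size == 1:
        return fill_len - prefix_len
    num_new_pages = math.ceil(fill_len / page_size) - prefix_len // page_size
    return num_new_pages * page_size

# 128 tokens already cached (2 pages of 64); filling to 150 needs 1 new page.
print(required_alloc_tokens(150, 128, 64))  # -> 64
print(required_alloc_tokens(128, 128, 64))  # -> 0
```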

Comment on lines +1017 to +1027
```python
if self.token_to_kv_pool_allocator.available_size() < required_alloc_tokens:
    logger.warning(
        f"Eviction insufficient: needed {required_alloc_tokens} tokens, "
        f"available {self.token_to_kv_pool_allocator.available_size()} "
        f"after evicting {result.num_tokens_evicted}/{num_to_evict} tokens. "
        f"evictable_size={self.tree_cache.evictable_size()}, "
        f"protected_size={self.tree_cache.protected_size()}, "
        f"fill_len={fill_len}, prefix_len={prefix_len}, delta_len={delta_len}, "
        f"page_size={self.token_to_kv_pool_allocator.page_size}, "
        f"req={req.rid}"
    )
```
Collaborator:

Should we crash if eviction fails?

Collaborator:

Crashing could be dangerous in production. If no memory leak happens, I think a warning should be fine.

Comment on lines +783 to +790
```python
if end_idx < start_idx:
    logger.debug(
        "send_kv_chunk skip: rid=%s start_send_idx=%s end_idx=%s",
        req.rid,
        start_idx,
        end_idx,
    )
    return
```
Collaborator:

When will this happen?

Collaborator:

When the prefill cache hit length is shorter than the decode cache hit length, prefill will run some chunks that don't need to be sent because decode already has them. Since the metadata moved start_idx past those chunks, end_idx can fall below start_idx.
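A toy sketch of that cursor arithmetic (names and signature hypothetical, not the actual sglang fields): the send cursor is clamped forward to the decode-side prefix, so a chunk that falls entirely inside the prefix produces an empty range and gets skipped.

```python
def chunk_send_range(start_send_idx, chunk_end, decode_prefix_len):
    # Decode already holds [0, decode_prefix_len); never move the cursor
    # backwards. A chunk fully covered by the prefix yields end < start.
    start_idx = max(start_send_idx, decode_prefix_len)
    end_idx = chunk_end
    if end_idx < start_idx:
        return None  # skip: decode already has this whole chunk
    return start_idx, end_idx

# Prefill's first chunk [0, 512) while decode matched a 1024-token prefix:
print(chunk_send_range(0, 512, 1024))     # -> None (skipped)
# Next chunk [512, 1536): only the tail [1024, 1536) is actually sent.
print(chunk_send_range(512, 1536, 1024))  # -> (1024, 1536)
```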

Collaborator:

I see. IMO this should be handled as backend-specific logic. It is up to the backend to decide whether chunks not needed by decode should be sent (e.g., for numerical reasons).
